# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import importlib
import os
import subprocess
from functools import partial

import torch
from tensordict import TensorDictBase
from torch import nn
from torchrl._utils import logger as torchrl_logger

from torchrl.data.tensor_specs import Composite, DEVICE_TYPING, TensorSpec, Unbounded
from torchrl.envs.transforms.transforms import (
    CenterCrop,
    Compose,
    ObservationNorm,
    Resize,
    ToTensorImage,
    Transform,
)
from torchrl.envs.transforms.utils import _set_missing_tolerance

_has_vc = importlib.util.find_spec("vc_models") is not None


class VC1Transform(Transform):
    """VC1 Transform class.

    VC1 provides pre-trained ResNet weights aimed at facilitating visual
    embedding for robotic tasks. The models are trained using Ego4d.

    See the paper:
        VC1: A Universal Visual Representation for Robot Manipulation (Suraj Nair,
            Aravind Rajeswaran, Vikash Kumar, Chelsea Finn, Abhinav Gupta)
            https://arxiv.org/abs/2203.12601

    The VC1Transform is created in a lazy manner: the object will be initialized
    only when an attribute (a spec or the forward method) will be queried.
    The reason for this is that the :obj:`_init()` method requires some attributes of
    the parent environment (if any) to be accessed: by making the class lazy we
    can ensure that the following code snippet works as expected:

    Examples:
        >>> transform = VC1Transform("default", in_keys=["pixels"])
        >>> env.append_transform(transform)
        >>> # the forward method will first call _init which will look at env.observation_spec
        >>> env.reset()

    Args:
        in_keys (list of NestedKeys): list of input keys. If left empty, the
            "pixels" key is assumed.
        out_keys (list of NestedKeys, optional): list of output keys. If left empty,
             "VC1_vec" is assumed.
        model_name (str): One of ``"large"``, ``"base"`` or any other compatible
            model name (see the `github repo <https://github.com/facebookresearch/eai-vc>`_ for more info). Defaults to ``"default"``
            which provides a small, untrained model for testing.
        del_keys (bool, optional): If ``True`` (default), the input key will be
            discarded from the returned tensordict.
    """

    inplace = False
    IMPORT_ERROR = (
        "Could not load vc_models. You can install it via "
        "VC1Transform.install_vc_models()."
    )

    def __init__(self, in_keys, out_keys, model_name, del_keys: bool = True):
        if model_name == "default":
            self.make_noload_model()
            model_name = "vc1_vitb_noload"
        self.model_name = model_name
        self.del_keys = del_keys

        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self._init()

    def _init(self):
        try:
            from vc_models.models.vit import model_utils
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(self.IMPORT_ERROR) from err

        if self.model_name == "base":
            model_name = model_utils.VC1_BASE_NAME
        elif self.model_name == "large":
            model_name = model_utils.VC1_LARGE_NAME
        else:
            model_name = self.model_name

        model, embd_size, model_transforms, model_info = model_utils.load_model(
            model_name
        )
        self.model = model
        self.embd_size = embd_size
        self.model_transforms = self._map_tv_to_torchrl(model_transforms)

    def _map_tv_to_torchrl(
        self,
        model_transforms,
        in_keys=None,
    ):
        if in_keys is None:
            in_keys = self.in_keys
        from torchvision import transforms

        if isinstance(model_transforms, transforms.Resize):
            size = model_transforms.size
            if isinstance(size, int):
                size = (size, size)
            return Resize(
                *size,
                in_keys=in_keys,
            )
        elif isinstance(model_transforms, transforms.CenterCrop):
            size = model_transforms.size
            if isinstance(size, int):
                size = (size,)
            return CenterCrop(
                *size,
                in_keys=in_keys,
            )
        elif isinstance(model_transforms, transforms.Normalize):
            return ObservationNorm(
                in_keys=in_keys,
                loc=torch.as_tensor(model_transforms.mean).reshape(3, 1, 1),
                scale=torch.as_tensor(model_transforms.std).reshape(3, 1, 1),
                standard_normal=True,
            )
        elif isinstance(model_transforms, transforms.ToTensor):
            return ToTensorImage(
                in_keys=in_keys,
            )
        elif isinstance(model_transforms, transforms.Compose):
            transform_list = []
            for t in model_transforms.transforms:

                if isinstance(t, transforms.ToTensor):
                    transform_list.insert(0, t)
                else:
                    transform_list.append(t)
            if len(transform_list) == 0:
                raise RuntimeError("Did not find any transform.")
            for i, t in enumerate(transform_list):
                if i == 0:
                    transform_list[i] = self._map_tv_to_torchrl(t)
                else:
                    transform_list[i] = self._map_tv_to_torchrl(t)
            return Compose(*transform_list)
        else:
            raise NotImplementedError(type(model_transforms))

    def _call(self, next_tensordict):
        if not self.del_keys:
            in_keys = [
                in_key
                for in_key, out_key in zip(self.in_keys, self.out_keys)
                if in_key != out_key
            ]
            saved_td = next_tensordict.select(*in_keys)
        with next_tensordict.view(-1) as tensordict_view:
            super()._call(self.model_transforms(tensordict_view))
        if self.del_keys:
            next_tensordict.exclude(*self.in_keys, inplace=True)
        else:
            # reset in_keys
            next_tensordict.update(saved_td)
        return next_tensordict

    forward = _call

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        # TODO: Check this makes sense
        with _set_missing_tolerance(self, True):
            tensordict_reset = self._call(tensordict_reset)
        return tensordict_reset

    @torch.no_grad()
    def _apply_transform(self, obs: torch.Tensor) -> None:
        shape = None
        if obs.ndimension() > 4:
            shape = obs.shape[:-3]
            obs = obs.flatten(0, -4)
        out = self.model(obs)
        if shape is not None:
            out = out.view(*shape, *out.shape[1:])
        return out

    def transform_observation_spec(self, observation_spec: TensorSpec) -> TensorSpec:
        if not isinstance(observation_spec, Composite):
            raise ValueError("VC1Transform can only infer Composite")

        keys = [key for key in observation_spec.keys(True, True) if key in self.in_keys]
        device = observation_spec[keys[0]].device
        dim = observation_spec[keys[0]].shape[:-3]

        observation_spec = observation_spec.clone()
        if self.del_keys:
            for in_key in keys:
                del observation_spec[in_key]

        for out_key in self.out_keys:
            observation_spec[out_key] = Unbounded(
                shape=torch.Size([*dim, self.embd_size]), device=device
            )

        return observation_spec

    def to(self, dest: DEVICE_TYPING | torch.dtype):
        if isinstance(dest, torch.dtype):
            self._dtype = dest
        else:
            self._device = dest
        return super().to(dest)

    @property
    def device(self):
        return self._device

    @property
    def dtype(self):
        return self._dtype

    @classmethod
    def install_vc_models(cls, auto_exit=False):
        try:
            from vc_models import models  # noqa: F401

            torchrl_logger.info("vc_models found, no need to install.")
        except ModuleNotFoundError:
            HOME = os.environ.get("HOME")
            vcdir = HOME + "/.cache/torchrl/eai-vc"
            parentdir = os.path.dirname(os.path.abspath(vcdir))
            os.makedirs(parentdir, exist_ok=True)
            try:
                from git import Repo
            except ModuleNotFoundError as err:
                raise ModuleNotFoundError(
                    "Could not load git. Make sure that `git` has been installed "
                    "in your virtual environment."
                ) from err
            Repo.clone_from("https://github.com/facebookresearch/eai-vc.git", vcdir)
            os.chdir(vcdir + "/vc_models")
            subprocess.call(["python", "setup.py", "develop"])
            if not auto_exit:
                input(
                    "VC1 has been successfully installed. Exit this python run and "
                    "relaunch it again. Press Enter to exit..."
                )
                exit()

    @classmethod
    def make_noload_model(cls):
        """Creates an naive model at a custom destination."""
        import vc_models

        models_filepath = os.path.dirname(os.path.abspath(vc_models.__file__))
        cfg_path = os.path.join(
            models_filepath, "conf", "model", "vc1_vitb_noload.yaml"
        )
        if os.path.exists(cfg_path):
            return
        config = """_target_: vc_models.models.load_model
model:
  _target_: vc_models.models.vit.vit.load_mae_encoder
  checkpoint_path:
  model:
    _target_: torchrl.envs.transforms.vc1._vit_base_patch16
    img_size: 224
    use_cls: True
    drop_path_rate: 0.0
transform:
  _target_: vc_models.transforms.vit_transforms
metadata:
  algo: mae
  model: vit_base_patch16
  data:
    - ego
    - imagenet
    - inav
  comment: 182_epochs
"""
        with open(cfg_path, "w") as file:
            file.write(config)


def _vit_base_patch16(**kwargs):
    from vc_models.models.vit.vit import VisionTransformer

    model = VisionTransformer(
        patch_size=16,
        embed_dim=16,
        depth=4,
        num_heads=4,
        mlp_ratio=4,
        qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    return model
